import os
import sys
import gzip
import tempfile
import shutil
import pybedtools
import pysam
import openpyxl
from Bio.Seq import reverse_complement
from numpy import *

# Command-line arguments: the dataset subdirectory under the CASPARs data
# directory, and the library name (the input BAM file is "<library>.bam").
dataset = sys.argv[1]
library = sys.argv[2]

assembly = "hg38"

# XT tag values whose alignments are passed through without being annotated.
skip_targets = (
    'chrM',
    'rRNA',
    'tRNA',
    'snRNA',
    'scRNA',
    'snoRNA',
    'yRNA',
    'scaRNA',
    'snar',
    'vRNA',
    'RMRP',
    'RPPH',
)

# XT tag values whose alignments are candidates for annotation.  Every mapped
# alignment is expected to carry an XT value from one of these two tuples.
keep_targets = (
    'histone',
    'TERC',
    'MALAT1',
    'snhg',
    'mRNA',
    'lncRNA',
    'gencode',
    'fantomcat',
    'genome',
)


def _assert_header(sheet, expected):
    """Assert that spreadsheet header cells hold the expected text.

    Args:
        sheet: openpyxl worksheet to check.
        expected: mapping of cell reference (e.g. ``'A1'``) to the exact text
            that cell must contain.

    Raises:
        AssertionError: naming the first cell whose value differs.
    """
    for cell, text in expected.items():
        found = sheet[cell].value
        assert found == text, "cell %s: expected %r, found %r" % (cell, text, found)


def read_primirna_annotations():
    """Build primary-microRNA loci as a pybedtools.BedTool.

    Combines four sources:

    * miRBase release 22 ``hsa.gff3``: coordinates of every
      ``miRNA_primary_transcript`` feature, keyed by its ``Name`` attribute.
    * FANTOM5 sRNA table S4 (``nbt.3947-S6.xlsx``): the robust / permissive /
      candidate category of each entry (collected, currently unused).
    * FANTOM5 sRNA table S17 (``nbt.3947-S19.xlsx``): for each promoter, the
      primary microRNA name and the comma-separated pre-microRNA names.
    * FANTOM5 Zoo table S12a (``Supplemental_Table_S12.xlsx``): the curated
      hg38 chromosome, strand and TSS of each promoter.

    For each promoter, the interval runs from its TSS to the outermost
    coordinate of its pre-microRNAs found in miRBase on the same strand.
    Pre-microRNAs named ``hsa-novelmir*``, missing from miRBase, or on the
    opposite strand are skipped; promoters left with no contributing
    pre-microRNA are dropped.

    Returns:
        pybedtools.BedTool of intervals
        (chromosome, start, end, primary microRNA name, '.', strand).

    Raises:
        ValueError: if table S4 contains an unrecognized category.
        AssertionError: if a spreadsheet layout differs from the expected one.
    """
    # miRBase pre-microRNA loci: name -> (chromosome, strand, start, end).
    directory = "/osc-fs_home/mdehoon/Data/miRBase/Release22"
    path = os.path.join(directory, "hsa.gff3")
    print("Reading", path)
    premirnas = {}
    for line in pybedtools.BedTool(path):
        if line.fields[2] != 'miRNA_primary_transcript':
            continue
        premirnas[line.attrs['Name']] = (line.chrom, line.strand, line.start, line.end)

    # Table S4: validate the category column; robust names are collected but
    # filtering on them is intentionally disabled for now.
    directory = "/osc-fs_home/mdehoon/Data/Fantom5/sRNA/Published"
    path = os.path.join(directory, "nbt.3947-S6.xlsx")
    print("Reading", path)
    workbook = openpyxl.load_workbook(path)
    assert workbook.sheetnames == ['TableS4']
    sheet = workbook.active
    _assert_header(sheet, {
        'A1': "Name",
        'B1': "Set",
        'C1': "Sufficient tags?",
        'D1': "5' end consistent?",
        'E1': "3' end overhang?",
        # "Suffient" is the spelling used in the published table.
        'F1': "Suffient basepairing nucleotides?",
        'G1': "Suffient basepairing energy?",
        'H1': "Drosha CAGE peak p-value (FANTOM5)",
        'I1': "Drosha CAGE peak p-value (ENCODE)",
    })
    robust = set()  # ignore for now
    for row in range(2, sheet.max_row + 1):  # row 1 is the header
        number = str(row)
        name = sheet['A' + number].value
        category = sheet['B' + number].value
        if category == 'robust':
            robust.add(name)
        elif category not in ('permissive', 'candidate'):
            raise ValueError("%s: unexpected category %r in row %d"
                             % (path, category, row))

    # Table S17: promoter name -> pre-microRNA names and primary microRNA name.
    path = os.path.join(directory, "nbt.3947-S19.xlsx")
    print("Reading", path)
    workbook = openpyxl.load_workbook(path)
    assert workbook.sheetnames == ['TableS17']
    sheet = workbook.active
    _assert_header(sheet, {
        'A1': "Chromosome",
        'B1': "Strand",
        'C1': "TSS",
        'D1': "Promoter name",
        'E1': "Primary microRNA",
        'F1': "Intronic/intergenic",
        'G1': "Pre-microRNAs",
        'H1': "Conservation",
        'I1': "Maximum expression level (tpm)",
        'J1': "Highest expressing sample",
        'K1': "First most enriched in cell ontology cluster",
        'L1': "P-value",
        'M1': "log2(fold ratio)",
        'N1': "Second most enriched in cell ontology cluster",
        'O1': "P-value",
        'P1': "log2(fold ratio)",
        'Q1': "Third most enriched in cell ontology cluster",
        'R1': "P-value",
        'S1': "log2(fold ratio)",
        'T1': "First most depleted in cell ontology cluster",
        'U1': "P-value",
        'V1': "log2(fold ratio)",
        'W1': "Second most depleted in cell ontology cluster",
        'X1': "P-value",
        'Y1': "log2(fold ratio)",
        'Z1': "Third most depleted in cell ontology cluster",
        'AA1': "P-value",
        'AB1': "log2(fold ratio)",
    })
    premirna_names = {}
    primirna_names = {}
    for row in range(2, sheet.max_row + 1):  # row 1 is the header
        number = str(row)
        promoter_name = sheet['D' + number].value
        primirna_name = sheet['E' + number].value
        premirna_names[promoter_name] = sheet['G' + number].value.split(",")
        primirna_names[promoter_name] = primirna_name

    # Table S12a: curated TSS per promoter; build one interval per promoter.
    directory = "/osc-fs_home/mdehoon/Data/Fantom5/Zoo/Published"
    path = os.path.join(directory, "Supplemental_Table_S12.xlsx")
    print("Reading", path)
    workbook = openpyxl.load_workbook(path)
    assert len(workbook.sheetnames) == 6
    table_name = 'TableS12a'
    assert workbook.sheetnames[1] == table_name
    # Index the workbook by sheet name; get_sheet_by_name() is deprecated.
    sheet = workbook[table_name]
    _assert_header(sheet, {
        'A1': "Primary miRNAs, human, after manual curation",
        'D2': "Chromosome (hg38)",
        'E2': "Strand (hg38)",
        'F2': "TSS (hg38)",
        'G2': "Promoter name",
    })
    primirnas = []
    for row in range(3, sheet.max_row + 1):  # rows 1-2 hold title and header
        number = str(row)
        chromosome = sheet['D' + number].value
        strand = sheet['E' + number].value
        tss = sheet['F' + number].value
        promoter_name = sheet['G' + number].value
        primirna_name = primirna_names[promoter_name]
        start = end = tss
        for premirna_name in premirna_names[promoter_name]:
            # NOTE: restricting to pre-microRNAs in `robust` is disabled for now.
            if premirna_name.startswith("hsa-novelmir"):
                continue
            locus = premirnas.get(premirna_name)
            if locus is None:
                print("Skipping", premirna_name)
                continue
            premirna_chromosome, premirna_strand, premirna_start, premirna_end = locus
            assert premirna_chromosome == chromosome
            if premirna_strand != strand:
                print("Skipping", premirna_name)
                continue
            # Explicit comparisons instead of min()/max(): the module-level
            # `from numpy import *` can shadow those builtins in some NumPy
            # releases.
            if premirna_start < start:
                start = premirna_start
            if premirna_end > end:
                end = premirna_end
        if start == end:
            # No pre-microRNA contributed; nothing to annotate.
            continue
        fields = [chromosome, start, end, primirna_name, '.', strand]
        primirnas.append(pybedtools.create_interval_from_list(fields))
    print("Number of primirnas: %d" % len(primirnas))
    return pybedtools.BedTool(primirnas)

def read_fantom5_enhancers():
    """Yield the FANTOM5 hg38 enhancers as BED intervals.

    Reads the gzipped FANTOM5 enhancer BED file and keeps only the first six
    columns of each line.  The file is closed once the generator is
    exhausted or closed.

    Yields:
        pybedtools intervals built from the first six columns of each line.
    """
    directory = "/osc-fs_home/mdehoon"
    subdirectory = "Data/Fantom5/Enhancers"
    filename = "F5.hg38.enhancers.bed.gz"
    # obtained from https://zenodo.org/record/556775 (contact: Robin Andersson)
    path = os.path.join(directory, subdirectory, filename)
    print("Reading", path)
    with gzip.open(path, 'rt') as handle:
        for line in pybedtools.BedTool(handle):
            yield pybedtools.create_interval_from_list(line.fields[:6])

def read_epigenome_roadmap(category):
    """Yield Roadmap Epigenomics enhancer or dyadic regions as BED intervals.

    Args:
        category: ``"roadmap_enhancer"`` (reads ``enhancers.bed``) or
            ``"roadmap_dyadic"`` (reads ``dyadic.bed``).

    Yields:
        Intervals with all original columns plus one extra final column
        holding the region label ``"chrom:start-end"``.  Presumably the input
        files have three columns, so the label lands in the BED name column
        (TODO confirm against the files).

    Raises:
        ValueError: for any other *category*.  Because this is a generator,
            the error is raised on first iteration, not at call time.
    """
    directory = "/osc-fs_home/mdehoon/Data/RoadmapEpigenomics/"
    if category == "roadmap_enhancer":
        filename = "enhancers.bed"
    elif category == "roadmap_dyadic":
        filename = "dyadic.bed"
    else:
        raise ValueError("Unknown category %s" % category)
    path = os.path.join(directory, filename)
    print("Reading", path)
    with open(path) as handle:
        for line in pybedtools.BedTool(handle):
            name = "%s:%d-%d" % (line.chrom, line.start, line.end)
            yield pybedtools.create_interval_from_list(line.fields + [name])

def read_novel_enhancers(source):
    """Yield predicted novel enhancers as 4-column BED intervals.

    Reads ``enhancer_predictions/novel_enhancers.<source>.bed`` (relative to
    the working directory) and keeps the first four columns of each line.

    Args:
        source: prediction source used in the file name (e.g. ``"HiSeq"``
            or ``"CAGE"``).

    Yields:
        pybedtools intervals built from the first four columns of each line.
    """
    directory = "enhancer_predictions"
    filename = "novel_enhancers.%s.bed" % source
    path = os.path.join(directory, filename)
    print("Reading", path)
    with open(path) as handle:
        for line in pybedtools.BedTool(handle):
            yield pybedtools.create_interval_from_list(line.fields[:4])

def read_last_exon():
    """Yield the last exon of every transcript in the NCBI hg38 exon table.

    The third column of each line holds colon-separated terms
    ``counter:geneid:genename:transcript:exon_number:exons_total`` with
    0-based exon numbers.  On the plus strand the last exon is the one
    numbered ``exons_total - 1``; on the minus strand it is exon 0.  This
    assumes that exons are numbered in ascending genomic order (TODO confirm
    against the file).

    Yields:
        6-column intervals (chromosome, start, end, transcript, "0", strand).

    Raises:
        ValueError: if a line has a strand other than "+" or "-".
        AssertionError: if an exon number is outside ``[0, exons_total)``.
    """
    path = "/osc-fs_home/mdehoon/Data/NCBI/hg38/exons.gff"
    print("Reading", path)
    with open(path) as handle:
        for line in pybedtools.BedTool(handle):
            terms = line.fields[2].split(':')
            transcript = terms[3]
            exon_number = int(terms[4])
            exons_total = int(terms[5])
            assert 0 <= exon_number < exons_total
            strand = line.strand
            if strand == "+":
                last_exon_number = exons_total - 1
            elif strand == "-":
                last_exon_number = 0
            else:
                raise ValueError("Unknown strand %s" % strand)
            if exon_number != last_exon_number:
                continue
            fields = [line.chrom, line.start, line.end, transcript, "0", strand]
            yield pybedtools.create_interval_from_list(fields)

def read_annotations(category):
    """Yield the annotation intervals for one annotation category.

    Dispatches to the reader for *category*: ``'FANTOM5_enhancer'``,
    ``'roadmap_enhancer'``, ``'roadmap_dyadic'``, ``'last_exon'``,
    ``'novel_enhancer_HiSeq'`` or ``'novel_enhancer_CAGE'``.  Because this
    is a generator, nothing is read (and no error is raised) until the first
    item is requested.

    Raises:
        Exception: "Unknown gene category ..." for any other *category*.
    """
    # Lambdas defer looking up the reader until the category is known.
    readers = {
        'FANTOM5_enhancer': lambda: read_fantom5_enhancers(),
        'roadmap_enhancer': lambda: read_epigenome_roadmap(category),
        'roadmap_dyadic': lambda: read_epigenome_roadmap(category),
        'last_exon': lambda: read_last_exon(),
        'novel_enhancer_HiSeq': lambda: read_novel_enhancers("HiSeq"),
        'novel_enhancer_CAGE': lambda: read_novel_enhancers("CAGE"),
    }
    reader = readers.get(category)
    if reader is None:
        raise Exception("Unknown gene category %s" % category)
    yield from reader()

def parse_bamfile(path):
    """Yield BED intervals for the candidate single-end alignments in a BAM file.

    An alignment is kept when it is mapped, its XT tag is in
    ``keep_targets``, it has no XA tag yet (XA is set by an earlier
    annotation pass of this script), and, if it has an XL tag, its reference
    span does not exceed the XL value (a longer span indicates splicing).
    Alignments whose XT tag is in ``skip_targets`` are ignored.

    Args:
        path: path to the BAM file.

    Yields:
        6-column intervals (chromosome, start, end, read name, HI tag as a
        string, strand).

    Raises:
        AssertionError: if an XT tag is in neither ``skip_targets`` nor
            ``keep_targets``.
    """
    print("Reading", path)
    with pysam.AlignmentFile(path) as records:
        for record in records:
            if record.is_unmapped:
                continue
            target = record.get_tag("XT")
            if target in skip_targets:
                continue
            assert target in keep_targets
            if record.has_tag("XA"):
                # Already annotated; leave it as is.
                continue
            chromosome = record.reference_name
            start = record.reference_start
            end = record.reference_end
            # We are annotating enhancer RNAs and transcripts in the last exon.
            # This should not include any spliced transcripts.
            if record.has_tag("XL") and record.get_tag("XL") < end - start:
                continue
            strand = "-" if record.is_reverse else "+"
            score = str(record.get_tag("HI"))
            fields = [chromosome, start, end, record.query_name, score, strand]
            yield pybedtools.create_interval_from_list(fields)

def parse_miseq_bamfile(path):
    """Yield BED intervals for candidate paired-end fragments in a BAM file.

    Records are consumed two at a time (read 1, then read 2).  For each
    mapped pair the fragment runs from the start of the forward-strand mate
    to the end of the reverse-strand mate.  The pair is kept when read 1's XT
    tag is in ``keep_targets``, read 1 has no XA tag yet (set by an earlier
    annotation pass of this script), and, if read 1 has an XL tag, the
    fragment span does not exceed it (a longer span indicates splicing).

    Args:
        path: path to the BAM file.

    Yields:
        6-column intervals (chromosome, fragment start, fragment end, read
        name, HI tag of read 1 as a string, strand of read 1).

    Raises:
        ValueError: if the file ends with an unpaired record.
        AssertionError: if a pair is inconsistent (only one mate unmapped,
            an empty aligned span, mates on the same strand) or if an XT tag
            is in neither ``skip_targets`` nor ``keep_targets``.
    """
    print("Reading", path)
    with pysam.AlignmentFile(path) as records:
        for read1 in records:
            read2 = next(records, None)
            if read2 is None:
                # A bare next() would raise StopIteration here, which a
                # generator turns into an opaque RuntimeError.
                raise ValueError("%s: record %s has no mate"
                                 % (path, read1.query_name))
            if read1.is_unmapped:
                assert read2.is_unmapped
                continue
            start1, end1 = read1.reference_start, read1.reference_end
            start2, end2 = read2.reference_start, read2.reference_end
            assert start1 < end1
            assert start2 < end2
            if read1.is_reverse:
                assert not read2.is_reverse, "%s\n%s" % (read1, read2)
                start, end = start2, end1
            else:
                assert read2.is_reverse, "%s\n%s" % (read1, read2)
                start, end = start1, end2
            target = read1.get_tag("XT")
            if target in skip_targets:
                continue
            assert target in keep_targets
            if read1.has_tag("XA"):
                # Already annotated; leave it as is.
                continue
            # We are annotating enhancer RNAs and transcripts in the last exon.
            # This should not include any spliced transcripts.
            if read1.has_tag("XL") and read1.get_tag("XL") < end - start:
                continue
            strand = "-" if read1.is_reverse else "+"
            score = str(read1.get_tag("HI"))
            fields = [read1.reference_name, start, end, read1.query_name,
                      score, strand]
            yield pybedtools.create_interval_from_list(fields)

def write_alignments(dataset, alignments, sequence1, sequence2, output_file=None):
    """Renumber and write the kept alignments of one read (or read pair).

    The i-th alignment (the i-th pair when ``dataset == 'MiSeq'``) gets HI
    tag i and NH tag equal to the number of alignments (pairs).  Record 0 is
    marked primary and all others secondary.  The primary read-1 record gets
    *sequence1*, and every read-2 record gets *sequence2*.  Each sequence is
    reverse-complemented when its record is on the reverse strand.  An
    unmapped record gets no HI/NH tags.

    Args:
        dataset: dataset name.  For ``'MiSeq'``, *alignments* alternates
            read-1 / read-2 records.
        alignments: list of pysam alignment records sharing one read name.
        sequence1: read-1 sequence in its original (forward) orientation.
        sequence2: read-2 sequence in its original orientation (MiSeq only).
        output_file: object with a ``write(alignment)`` method, normally an
            open pysam AlignmentFile.  Defaults to the module-level ``output``
            opened by the main loop, for backward compatibility.

    Raises:
        AssertionError: if an unmapped record is not the only alignment, or
            a MiSeq list has an odd number of records.
    """
    if output_file is None:
        output_file = output
    total = len(alignments)
    if dataset == 'MiSeq':
        assert total % 2 == 0
        total //= 2
    records = iter(alignments)
    # enumerate() and the next() call below share one iterator, so `number`
    # counts pairs in MiSeq mode.
    for number, alignment1 in enumerate(records):
        if alignment1.is_unmapped:
            assert number == 0
            assert total == 1
        else:
            alignment1.set_tag("HI", number)
            alignment1.set_tag("NH", total)
        if number == 0:
            # Only the primary read-1 record carries the read sequence.
            # NOTE(review): pysam invalidates query_qualities when
            # query_sequence is assigned; confirm that losing base qualities
            # is acceptable.
            if alignment1.is_reverse:
                alignment1.query_sequence = reverse_complement(sequence1)
            else:
                alignment1.query_sequence = sequence1
        alignment1.is_secondary = number != 0
        output_file.write(alignment1)
        if dataset == 'MiSeq':
            alignment2 = next(records)
            # Derive the orientation from the unmodified sequence2 for every
            # mate.  Rebinding sequence2 here would flip it back and forth
            # across successive reverse-strand mates.
            if alignment2.is_reverse:
                alignment2.query_sequence = reverse_complement(sequence2)
            else:
                alignment2.query_sequence = sequence2
            alignment2.is_secondary = number != 0
            output_file.write(alignment2)

# Chromosome sizes file.  Alignments and annotations are both sorted with it,
# so they share the same chromosome order when intersected.
chrom_sizes_path = "/osc-fs_home/scratch/mdehoon/Data/Genomes/hg38/hg38.chrom.sizes"

# Annotation passes, run in this order.  Each pass reads the BAM file written
# by the previous pass (see the end of the loop), so tags added by earlier
# passes are carried into the final "<library>.bam" in the working directory.
categories = (
    "last_exon",
    "FANTOM5_enhancer",
    "roadmap_enhancer",
    "roadmap_dyadic",
    "novel_enhancer_HiSeq",
    "novel_enhancer_CAGE",
)

filename = "%s.bam" % library
directory = "/osc-fs_home/mdehoon/Data/CASPARs/"
path = os.path.join(directory, dataset, "Mapping", filename)
for category in categories:
    # 1. Candidate alignments as sorted BED intervals
    #    (name = read name, score = HI index of the alignment).
    if dataset == "MiSeq":
        alignments = parse_miseq_bamfile(path)
    else:
        alignments = parse_bamfile(path)
        sequence2 = None
    alignments = pybedtools.BedTool(alignments)
    alignments = alignments.sort(g=chrom_sizes_path)
    # 2. Overlap with this category's annotations.
    #    associations[read name][HI index] = annotation name; if an alignment
    #    overlaps several annotations, the last one reported wins.
    associations = {}
    annotations = read_annotations(category)
    annotations = pybedtools.BedTool(annotations)
    annotations = annotations.sort(g=chrom_sizes_path)
    if category == "last_exon":
        # require same strand
        overlap = alignments.intersect(annotations, wa=True, wb=True, s=True)
    else:
        overlap = alignments.intersect(annotations, wb=True)
    for line in overlap:
        fields = line.fields
        assert len(fields) == 12
        alignment = pybedtools.create_interval_from_list(fields[:6])
        annotation = pybedtools.create_interval_from_list(fields[6:])
        if category == "last_exon":
            # The alignment's 5'-most position must not lie upstream of the
            # exon (it overlaps the exon, so it then lies within it).
            if alignment.strand == "+":
                if alignment.start < annotation.start:
                    continue
            elif alignment.strand == "-":
                if alignment.end > annotation.end:
                    continue
        name = alignment.name
        number = int(alignment.score)
        associations.setdefault(name, {})[number] = annotation.name
    # The count reported here is the number of annotated reads.
    print("%s: found %d new annotations" % (category, len(associations)))
    # 3. Rewrite the BAM file: tag annotated alignments, drop unannotated
    #    alignments of reads that are annotated elsewhere, and renumber the
    #    remaining alignments of each read via write_alignments().
    print("Reading %s" % path)
    alignments = pysam.AlignmentFile(path)
    current = ""
    stream = tempfile.NamedTemporaryFile(delete=False)
    stream.close()
    print("Writing %s" % stream.name)
    output = pysam.AlignmentFile(stream.name, "wb", template=alignments)
    skip = 0
    for alignment1 in alignments:
        name = alignment1.query_name
        if dataset == "MiSeq":
            alignment2 = next(alignments)
            assert alignment2.query_name == name
            assert alignment1.is_read1
            assert not alignment1.is_read2
            assert not alignment2.is_read1
            assert alignment2.is_read2
        if name == current:
            number += 1
        else:
            # First record of a new read: flush the previous read and keep
            # this read's sequences in their original orientation, because
            # its primary record may be dropped below.
            if current:
                write_alignments(dataset, selected_alignments, sequence1, sequence2)
            current = name
            number = 0
            selected_alignments = []
            if alignment1.is_reverse:
                sequence1 = reverse_complement(alignment1.query_sequence)
            else:
                sequence1 = alignment1.query_sequence
            if dataset == "MiSeq":
                if alignment2.is_reverse:
                    sequence2 = reverse_complement(alignment2.query_sequence)
                else:
                    sequence2 = alignment2.query_sequence
        if alignment1.is_unmapped:
            # Exact comparison, like every other MiSeq check; a substring test
            # (`in "MiSeq"`) would also match e.g. "" or "Seq".
            if dataset == "MiSeq":
                assert alignment2.is_unmapped
        else:
            assert alignment1.get_tag("HI") == number
            target = alignment1.get_tag("XT")
            if target in keep_targets:
                current_associations = associations.get(name)
                if current_associations is not None:
                    annotation = current_associations.get(number)
                    if annotation is None:
                        # other mapping locations of this read are annotated,
                        # but the current mapping location is not annotated
                        skip += 1
                        continue
                    assert category != annotation
                    if category == "last_exon":
                        if alignment1.has_tag("XE"):
                            raise Exception("spliced transcripts should not be annotated as last_exon")
                        alignment1.set_tag("XE", category)
                        # annotation is gene name
                        assert annotation != "."
                        try:
                            existing = alignment1.get_tag("XF")
                        except KeyError:
                            pass
                        else:
                            raise Exception("found existing gene tag XF (%s, %s)" % (existing, annotation))
                        alignment1.set_tag("XF", annotation)
                    else:
                        if alignment1.has_tag("XA"):
                            raise Exception("found existing annotation tag XA")
                        alignment1.set_tag("XA", category)
                        assert category in ("FANTOM5_enhancer",
                                            "roadmap_enhancer",
                                            "roadmap_dyadic",
                                            "novel_enhancer_HiSeq",
                                            "novel_enhancer_CAGE")
                        # annotation is enhancer locus name
                        assert annotation != "."
                        try:
                            existing = alignment1.get_tag("XG")
                        except KeyError:
                            pass
                        else:
                            raise Exception("%s: found existing gene tag XG (%s, %s)" % (alignment1.query_name, existing, annotation))
                        alignment1.set_tag("XG", annotation)
            else:
                assert target in skip_targets
        selected_alignments.append(alignment1)
        if dataset == "MiSeq":
            selected_alignments.append(alignment2)
    if current:
        write_alignments(dataset, selected_alignments, sequence1, sequence2)
    output.close()
    alignments.close()
    print("Number of removed lines: %d" % skip)
    print("Moving %s to %s" % (stream.name, filename))
    shutil.move(stream.name, filename)
    # The next category reads the file just written.
    path = filename
